
Enable MiMo-Audio-7B end-to-end inference on Intel XPU #2983

Draft
Liangyx2 wants to merge 3 commits into vllm-project:main from Liangyx2:MiMoAudio

Conversation

@Liangyx2
Contributor

PR Description

Motivation

MiMo-Audio (XiaomiMiMo) is a multimodal audio model that supports TTS, voice cloning, audio transcription, and spoken dialogue. It currently runs only on CUDA. This PR enables MiMo-Audio inference on XPU (Intel GPU) by adding platform-specific stage configs and fixing several CUDA-only code paths that prevented the model from loading and running on non-CUDA devices.

Technical Details

  1. XPU stage config (mimo_audio.yaml): Added a 2-stage pipeline config (Stage 0: fused_thinker_talker for LLM + audio code generation, Stage 1: code2wav for waveform synthesis) with XPU-specific knobs (enforce_eager, disable_hybrid_kv_cache_manager, skip_mm_profiling, memory utilization tuning).

  2. Guard CUDA-only APIs: Wrapped all torch.cuda.is_current_stream_capturing() calls in mimo_audio.py, mimo_audio_code2wav.py, and mimo_audio_llm.py with torch.cuda.is_available() and device.type == "cuda" checks, returning False on non-CUDA devices. This prevents runtime errors on XPU.

  3. Fix device-hardcoded defaults: Removed torch.device(f"cuda:{torch.cuda.current_device()}") defaults in mimo_audio_llm.py's generate_audio_tokens / generate_audio_tokens_one_step methods, replacing them with local_embeds.device to be device-agnostic.

  4. Fix multimodal processor: Added a _hf_processor_applies_updates() -> False override in MiMoAudioLLMMultiModalProcessor so that vLLM applies the prompt updates (audio placeholder expansion) itself instead of assuming the HF processor has already done so.

  5. Robustness improvements in end2end.py:

    • Reference audio truncation to 8 seconds (MAX_REF_AUDIO_SAMPLES) to prevent model confusion (repetition / voice identity loss) with long clips.
    • Truncation of code2wav input tokens to MAX_CODE2WAV_TOKENS=8192 in the stage input processor to prevent OOM.
    • Skip invalid/empty audio outputs (< 10ms) instead of writing corrupt WAV files.
    • Default --text to None so each query type uses its own sensible default.
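The truncation and output-validation rules in item 5 can be sketched roughly as follows. This is a minimal illustration, not the code in end2end.py: the helper names and the 24 kHz `SAMPLE_RATE` are assumptions (the PR does not state a sample rate), and keeping the leading tokens when truncating is also an assumption. Only the 8-second limit, the 8192-token cap, and the 10 ms threshold come from the description above.

```python
from typing import List, Optional, Sequence

SAMPLE_RATE = 24_000  # assumed sample rate; not stated in the PR
MAX_REF_AUDIO_SAMPLES = 8 * SAMPLE_RATE  # 8 seconds of reference audio
MAX_CODE2WAV_TOKENS = 8192  # cap named in the PR description
MIN_OUTPUT_SAMPLES = SAMPLE_RATE // 100  # 10 ms worth of samples


def truncate_reference_audio(samples: Sequence[float]) -> List[float]:
    """Keep at most the first 8 seconds of the reference clip."""
    return list(samples[:MAX_REF_AUDIO_SAMPLES])


def truncate_code_tokens(tokens: Sequence[int]) -> List[int]:
    """Keep at most MAX_CODE2WAV_TOKENS tokens (assumption: the leading ones)."""
    return list(tokens[:MAX_CODE2WAV_TOKENS])


def is_valid_output(samples: Optional[Sequence[float]]) -> bool:
    """Return False for missing, empty, or shorter-than-10 ms audio."""
    if samples is None:
        return False
    return len(samples) >= MIN_OUTPUT_SAMPLES
```

A caller would check `is_valid_output(...)` before writing a WAV file and skip the output otherwise.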

Performance Impact

  • No regression on CUDA — all changes are guarded by device-type checks.
  • Enables single-GPU XPU inference for MiMo-Audio with enforce_eager=true and conservative memory settings (gpu_memory_utilization: 0.4 / 0.35).
  • Reference audio truncation and code2wav token truncation improve stability and prevent OOM on memory-constrained XPU devices.

Workload Mapping

| Workload | Model | Platform | Config |
| --- | --- | --- | --- |
| MiMo-Audio TTS / Voice Cloning / Audio Understanding / Spoken Dialogue | XiaomiMiMo/MiMo-Audio | XPU (Intel GPU) | mimo_audio.yaml (2-stage: fused_thinker_talker + code2wav) |

@Liangyx2 Liangyx2 requested a review from hsliuustc0106 as a code owner April 21, 2026 07:44

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


NON-BLOCKING:

  • Test Coverage — XPU is experimental and CI does not run on XPU hardware. Since this PR adds device-type guards and XPU-specific configuration, please verify manually on XPU that:

    1. The model loads successfully with the new mimo_audio.yaml config
    2. Inference produces valid audio output for at least one query type (e.g., tts_sft)
    3. No runtime errors from CUDA-specific APIs on XPU

    Consider adding a note in the PR description confirming which XPU configuration was tested.

  num_reqs = len(request_ids)
- is_capturing = torch.cuda.is_current_stream_capturing()
+ if torch.cuda.is_available() and input_ids.device.type == "cuda":
+     is_capturing = torch.cuda.is_current_stream_capturing()
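The guard in this hunk could also be factored into one helper, so that each call site does not repeat the device check. The sketch below is only an illustration: the name `is_stream_capturing` is hypothetical and not part of the PR, and the caller is assumed to pass something like `input_ids.device.type`. The lazy import is a choice made for this sketch, so that the non-CUDA path runs without touching PyTorch at all.

```python
def is_stream_capturing(device_type: str) -> bool:
    """Return True only when the current CUDA stream is being captured."""
    if device_type != "cuda":
        # XPU, CPU, and other backends: never call into torch.cuda.
        return False
    # Lazy import (a choice for this sketch): the non-CUDA path above
    # returns before PyTorch is needed.
    import torch

    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
```

With such a helper, the call site would reduce to `is_capturing = is_stream_capturing(input_ids.device.type)`, which also covers the non-CUDA case by returning False.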
Collaborator


Don't other platforms support torch.xxx.is_current_stream_capturing()?

@@ -0,0 +1,103 @@
# XPU stage config for running MiMo-Audio with 2-stage architecture
Collaborator


We don't introduce new stage configs. Please refer to #2383 and add a correct deploy config instead.


@hsliuustc0106
Collaborator

cc @qibaoyuan

@qibaoyuan
Contributor

Thanks! Could you help us test this incoming PR on XPU and report any issues you encounter?
#2183

- Guard torch.cuda.is_current_stream_capturing() with device-type checks
  in mimo_audio.py, mimo_audio_code2wav.py, mimo_audio_llm.py to prevent
  RuntimeError on non-CUDA devices (XPU).
- Replace hardcoded CUDA device defaults in base_local_forward/local_forward
  with local_embeds.device for device-agnostic behavior.
- Add _hf_processor_applies_updates() override in MiMoAudioLLMMultiModalProcessor
  so vllm correctly applies prompt updates on XPU.
- Add platforms.xpu section in mimo_audio.yaml deploy config with
  XPU-specific knobs (enforce_eager, disable_hybrid_kv_cache_manager, etc).
- Fix end2end.py --text default from empty string to None so tts_sft
  correctly uses 'The weather is so nice today.' when no text is specified.
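The `--text` change in the last bullet can be illustrated with a small argparse sketch. It is not the actual end2end.py code: the `--query-type` flag, the `DEFAULT_TEXT` table, and `resolve_text` are assumptions made for illustration. Only the tts_sft default sentence and the switch from an empty string to None come from the commit message.

```python
import argparse
from typing import List, Optional

# Per-query-type default prompts. Only the tts_sft text comes from the
# commit message; any other entry would be an assumption.
DEFAULT_TEXT = {"tts_sft": "The weather is so nice today."}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Hypothetical flag name for selecting the query type.
    parser.add_argument("--query-type", default="tts_sft")
    # None (not "") lets the code tell "not provided" from "provided but empty".
    parser.add_argument("--text", default=None)
    return parser


def resolve_text(argv: List[str]) -> Optional[str]:
    """Return the user-supplied text, else the query type's own default."""
    args = build_parser().parse_args(argv)
    if args.text is not None:
        return args.text
    return DEFAULT_TEXT.get(args.query_type)
```

Under these assumptions, `resolve_text([])` falls back to the tts_sft sentence, while `resolve_text(["--text", "hi"])` returns the supplied text.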
Liangyx2 added 2 commits June 2, 2026 16:01
- Stage 0: gpu_memory_utilization 0.5 -> 0.95, max_model_len 8192 -> 3000
- Stage 1: gpu_memory_utilization 0.35 -> 0.9
- Update device assignments: stage 0 on device 1, stage 1 on device 2

Fixes OOM error where model weights (16.17 GiB) exceeded the memory
budget allocated by gpu_memory_utilization on 24.5 GiB XPU devices.
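The arithmetic behind this fix can be checked with a short sketch. It assumes that gpu_memory_utilization is a fraction of total device memory and ignores everything other than weights (KV cache, activations), so it is a rough lower bound rather than a full memory model. The function names are illustrative only.

```python
DEVICE_MEMORY_GIB = 24.5  # per-device memory quoted in the commit message
MODEL_WEIGHTS_GIB = 16.17  # model weight size quoted in the commit message


def memory_budget_gib(gpu_memory_utilization: float) -> float:
    """Memory budget implied by a utilization fraction (simplified model)."""
    return DEVICE_MEMORY_GIB * gpu_memory_utilization


def weights_fit(gpu_memory_utilization: float) -> bool:
    """True if the weights alone fit within the implied budget."""
    return MODEL_WEIGHTS_GIB <= memory_budget_gib(gpu_memory_utilization)
```

At 0.5 the budget is 12.25 GiB, which is below the 16.17 GiB of weights. At 0.95 it is about 23.28 GiB, which leaves room for the weights plus some working memory.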
